import os
import csv
import random

def split_by_ratio(data, ratios, *, train_path="train.csv",
                   validate_path="validate.csv", test_path="test.csv",
                   other_path="other.csv"):
    """Randomly split each group of rows into train/validate/test/other CSV files.

    Each element of *data* is one group (e.g. one category). Within a group the
    rows are shuffled, then the first ``int(ratios[0] * n)`` rows go to the
    train file, the next ``int(ratios[1] * n)`` to the validate file, the next
    ``int(ratios[2] * n)`` to the test file, and whatever is left (including
    rows lost to integer truncation) to the "other" file. No header row is
    written.

    :param data: iterable of row groups; each group is a sequence of CSV rows
        (e.g. ``(filepath, label)`` tuples). The groups are not modified.
    :param ratios: at least three fractions ``[train, validate, test]``.
    :param train_path: output CSV for the train split (overwritten).
    :param validate_path: output CSV for the validate split (overwritten).
    :param test_path: output CSV for the test split (overwritten).
    :param other_path: output CSV for the leftover rows (overwritten).
    """
    from contextlib import ExitStack

    paths = (train_path, validate_path, test_path, other_path)
    # ExitStack guarantees every opened file is closed even if a write fails.
    with ExitStack() as stack:
        train_writer, validate_writer, test_writer, other_writer = [
            csv.writer(stack.enter_context(
                open(path, mode="w", newline="", encoding="utf-8")))
            for path in paths
        ]
        for group in data:
            # Shuffle a copy so the caller's lists are left untouched.
            rows = list(group)
            n = len(rows)
            end_train = int(ratios[0] * n)
            end_validate = end_train + int(ratios[1] * n)
            end_test = end_validate + int(ratios[2] * n)
            random.shuffle(rows)
            train_writer.writerows(rows[:end_train])
            validate_writer.writerows(rows[end_train:end_validate])
            test_writer.writerows(rows[end_validate:end_test])
            other_writer.writerows(rows[end_test:])


def split_files_by_ratio(folder_root, ratios):
    """Collect files per category folder and hand them to ``split_by_ratio``.

    Every immediate sub-directory of *folder_root* is treated as one category.
    Categories are labelled 0, 1, 2, ... in sorted directory-name order, so the
    label numbering is stable across runs and platforms (``os.listdir`` order is
    arbitrary). Each regular file directly inside a category folder becomes a
    ``(path, label)`` row, where ``path`` is ``"<category_path>/<file name>"``;
    nested sub-directories are skipped. The per-category row lists are then
    passed to ``split_by_ratio``, which writes the split CSV files.

    :param folder_root: parent directory containing one folder per category.
    :param ratios: fractions ``[train, validate, test]`` applied per category,
        e.g. ``[0.6, 0.2, 0.2]``.
    """
    categories = sorted(
        name for name in os.listdir(folder_root)
        if os.path.isdir(os.path.join(folder_root, name))
    )
    data = []
    for label, category in enumerate(categories):
        category_path = os.path.join(folder_root, category)
        # Sort file names as well so the pre-shuffle order is reproducible.
        items = [
            (f"{category_path}/{name}", label)
            for name in sorted(os.listdir(category_path))
            if os.path.isfile(os.path.join(category_path, name))
        ]
        data.append(items)
    split_by_ratio(data, ratios)

def main():
    """Split the configured dataset folder using a fixed ratio list."""
    # Replace with the path to your dataset folder.
    source_dir = 'datasets/Vegetable Images/validation'
    # Fractions for [train, validate, test]; see split_by_ratio for details.
    split_fractions = [0, 0, 1]
    split_files_by_ratio(source_dir, split_fractions)

# Only run the split when executed as a script, not when imported.
if __name__ == "__main__":
    main()